import torchvision.transforms as transforms
from datasets import load_dataset, concatenate_datasets

from main.dataset import *

def create_dataloder(args, is_train):
    """Build a DataLoader over a single dataset with cached latents.

    Loads ``args.dataset`` via ``get_dataset_base``. The result is used through
    the Hugging Face ``datasets`` API (``with_transform`` / ``select``). An
    on-the-fly transform then turns each record into tensors:

    * the image goes through ``ToTensor`` and is resized to 512x512;
    * the cached latents and the watermarking mask become tensors with
      ``squeeze(..., 0)`` applied;
    * the ground-truth patch is rebuilt as a complex tensor from its stored
      real and imaginary parts.

    Args:
        args: Configuration namespace. Reads ``use_cached_latents``,
            ``dataset_path``, ``dataset``, ``num_images`` (only when
            ``is_train``), ``batch_size`` and ``num_workers``.
        is_train: Forwarded to ``get_dataset_base``. When true, the dataset is
            also restricted to its first ``args.num_images`` records.

    Returns:
        A DataLoader (shuffle not enabled) whose batches are dicts with keys
        ``"image"`` (contiguous float32 tensor), ``"latents"``, ``"gt_patch"``
        (complex tensor) and ``"watermarking_mask"``. Each value is the
        per-example tensors stacked along a new leading dimension.

    Raises:
        NotImplementedError: If ``args.use_cached_latents`` is false; computing
            latents on the fly is not supported yet.
    """
    # Fail fast: only the cached-latents path exists, so don't load the
    # dataset just to reject the configuration afterwards.
    if not args.use_cached_latents:
        # TODO: support computing latents on the fly (an earlier sketch built
        # empty-prompt text embeddings with the pipeline and passed them to
        # get_init_latent together with the ground-truth image tensor).
        raise NotImplementedError(
            "create_dataloder currently requires args.use_cached_latents=True"
        )

    dataset = get_dataset_base(args.dataset_path, args.dataset, is_train=is_train)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((512, 512)),
    ])

    def transforms_all(examples):
        # `examples` is a dict of per-column lists (with_transform operates on
        # batches); new tensor columns are added alongside the originals.
        examples['img_tensor'] = [transform(image) for image in examples["image"]]
        # squeeze(..., 0) only drops dim 0 when it has size 1 — presumably a
        # leading singleton dim in the cached data; TODO confirm.
        examples['latents_tensor'] = torch.squeeze(torch.tensor(examples['latents']), 0)
        gt_patch_real = torch.tensor(examples['gt_patch_real'])
        gt_patch_imag = torch.tensor(examples['gt_patch_imag'])
        examples['gt_patch_tensor'] = torch.complex(gt_patch_real, gt_patch_imag)
        examples['watermarking_mask_tensor'] = torch.squeeze(torch.tensor(examples['watermarking_mask']), 0)
        return examples

    columns = ['image', 'latents', 'gt_patch_real', 'gt_patch_imag', 'watermarking_mask']
    transformed_dataset = dataset.with_transform(transforms_all, columns=columns)
    if is_train:
        transformed_dataset = transformed_dataset.select(range(args.num_images))

    def collate_fn(examples):
        # Stack the per-example tensors produced by transforms_all into batches.
        pixel_values = torch.stack([example["img_tensor"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        latents = torch.stack([example["latents_tensor"] for example in examples])
        gt_patch = torch.stack([example["gt_patch_tensor"] for example in examples])
        watermarking_mask = torch.stack([example["watermarking_mask_tensor"] for example in examples])

        return {"image": pixel_values, "latents": latents, 'gt_patch': gt_patch,
                'watermarking_mask': watermarking_mask}

    return DataLoader(transformed_dataset, batch_size=args.batch_size,
                      collate_fn=collate_fn, num_workers=args.num_workers)

def create_dataloder_all(args, is_train):
    """Build one DataLoader over the diffusiondb, coco and wikiart datasets combined.

    Each of the three datasets is loaded via ``get_dataset_base``. Each gets the
    same on-the-fly transform, and the results are joined with
    ``concatenate_datasets`` in the order diffusiondb, coco, wikiart.

    The transform turns each record into tensors:

    * the image goes through ``ToTensor`` and is resized to 512x512;
    * the cached latents and the watermarking mask become tensors with
      ``squeeze(..., 0)`` applied;
    * the ground-truth patch is rebuilt as a complex tensor from its stored
      real and imaginary parts.

    Args:
        args: Configuration namespace. Reads ``use_cached_latents``,
            ``dataset_path``, ``num_images`` (only when ``is_train``),
            ``batch_size`` and ``num_workers``.
        is_train: Forwarded to ``get_dataset_base``. When true, each of the
            three datasets is restricted to its first ``args.num_images``
            records before concatenation.

    Returns:
        A DataLoader (shuffle not enabled) whose batches are dicts with keys
        ``"image"`` (contiguous float32 tensor), ``"latents"``, ``"gt_patch"``
        (complex tensor) and ``"watermarking_mask"``. Each value is the
        per-example tensors stacked along a new leading dimension.

    Raises:
        NotImplementedError: If ``args.use_cached_latents`` is false; computing
            latents on the fly is not supported yet.
    """
    # Fail fast: only the cached-latents path exists, so don't load three
    # datasets just to reject the configuration afterwards.
    if not args.use_cached_latents:
        # TODO: support computing latents on the fly (an earlier sketch built
        # empty-prompt text embeddings with the pipeline and passed them to
        # get_init_latent together with the ground-truth image tensor).
        raise NotImplementedError(
            "create_dataloder_all currently requires args.use_cached_latents=True"
        )

    dataset_names = ('diffusiondb', 'coco', 'wikiart')
    base_datasets = [
        get_dataset_base(args.dataset_path, name, is_train=is_train)
        for name in dataset_names
    ]

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((512, 512)),
    ])

    def transforms_all(examples):
        # `examples` is a dict of per-column lists (with_transform operates on
        # batches); new tensor columns are added alongside the originals.
        examples['img_tensor'] = [transform(image) for image in examples["image"]]
        # squeeze(..., 0) only drops dim 0 when it has size 1 — presumably a
        # leading singleton dim in the cached data; TODO confirm.
        examples['latents_tensor'] = torch.squeeze(torch.tensor(examples['latents']), 0)
        gt_patch_real = torch.tensor(examples['gt_patch_real'])
        gt_patch_imag = torch.tensor(examples['gt_patch_imag'])
        examples['gt_patch_tensor'] = torch.complex(gt_patch_real, gt_patch_imag)
        examples['watermarking_mask_tensor'] = torch.squeeze(torch.tensor(examples['watermarking_mask']), 0)
        return examples

    columns = ['image', 'latents', 'gt_patch_real', 'gt_patch_imag', 'watermarking_mask']
    # The same transform function is attached to every source, so their
    # formats match and concatenate_datasets keeps the transform.
    transformed = [ds.with_transform(transforms_all, columns=columns) for ds in base_datasets]
    if is_train:
        # Cap each source separately, so the training set holds
        # args.num_images records from every dataset.
        transformed = [ds.select(range(args.num_images)) for ds in transformed]
    transformed_dataset = concatenate_datasets(transformed)

    def collate_fn(examples):
        # Stack the per-example tensors produced by transforms_all into batches.
        pixel_values = torch.stack([example["img_tensor"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        latents = torch.stack([example["latents_tensor"] for example in examples])
        gt_patch = torch.stack([example["gt_patch_tensor"] for example in examples])
        watermarking_mask = torch.stack([example["watermarking_mask_tensor"] for example in examples])

        return {"image": pixel_values, "latents": latents, 'gt_patch': gt_patch,
                'watermarking_mask': watermarking_mask}

    # NOTE(review): shuffle is not enabled, so batches walk through the sources
    # in concatenation order (diffusiondb, then coco, then wikiart), in
    # training as well — confirm this is intended.
    return DataLoader(transformed_dataset, batch_size=args.batch_size,
                      collate_fn=collate_fn, num_workers=args.num_workers)
